15. Linear Regression in scikit-learn
Linear Regression
In this section, you'll use linear regression to predict life expectancy from body mass index (BMI). Before you do that, let's go over the tools required to build this model.
For your linear regression model, you'll be using scikit-learn's LinearRegression class. This class provides the method fit() to fit the model to your data.
>>> from sklearn.linear_model import LinearRegression
>>> model = LinearRegression()
>>> model.fit(x_values, y_values)
In the example above, the model variable is a linear regression model that has been fitted to the data x_values and y_values. Fitting the model means finding the line that best fits the training data. Let's make two predictions using the model's predict() method.
>>> print(model.predict([ [127], [248] ]))
[[ 438.94308857, 127.14839521]]
The model returned an array of predictions, one prediction for each input array. The first input, [127], got a prediction of 438.94308857. The second input, [248], got a prediction of 127.14839521. The reason for predicting on an array like [127] and not just 127 is that a model can make a prediction using multiple features, so each input is a row with one value per feature. We'll go over using multiple variables in linear regression later in this lesson. For now, let's stick to a single value.
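To make "finding the best line" concrete, here is a minimal sketch of a one-feature least-squares fit in plain Python. It is an illustration only, not scikit-learn's implementation: the class name SimpleLine and the toy data are made up for this example. It also mirrors the input shape discussed above, where each input is a row such as [127] holding one value per feature.

```python
# Illustrative sketch only: SimpleLine is a hypothetical class,
# not part of scikit-learn, and handles exactly one feature.

class SimpleLine:
    def fit(self, rows, targets):
        # rows is a list of one-element lists, e.g. [[1.0], [2.0]]
        xs = [row[0] for row in rows]
        n = len(xs)
        mean_x = sum(xs) / n
        mean_y = sum(targets) / n
        # Closed-form least-squares slope and intercept for one feature
        sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, targets))
        sxx = sum((x - mean_x) ** 2 for x in xs)
        self.slope = sxy / sxx
        self.intercept = mean_y - self.slope * mean_x
        return self

    def predict(self, rows):
        # One prediction per input row
        return [self.intercept + self.slope * row[0] for row in rows]


# Made-up points that lie exactly on the line y = 2x + 1
toy_x = [[1.0], [2.0], [3.0], [4.0]]
toy_y = [3.0, 5.0, 7.0, 9.0]

line = SimpleLine().fit(toy_x, toy_y)
predictions = line.predict([[5.0], [10.0]])
print(predictions)
```

Because the toy points sit exactly on a line, the fitted slope and intercept recover it, and each input row yields one prediction. With several features, each row would simply hold more values.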
Linear Regression Quiz
In this quiz, you'll be working with data on the average life expectancy at birth and the average BMI for males across the world. The data comes from Gapminder.
The data file can be found under the "bmi_and_life_expectancy.csv" tab in the quiz below. It includes three columns, containing the following data:
- Country – The country the person was born in.
- Life expectancy – The average life expectancy at birth for a person in that country.
- BMI – The mean BMI of males in that country.
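As a quick look at how these three columns are laid out, the sketch below parses a few rows copied from the data file using Python's standard csv module. This is only an assumption-laden illustration (the quiz itself asks for pandas); the variable names are made up here. It shows that a quoted country name containing a comma is read as a single field, and how the BMI column can be arranged as one row per country.

```python
import csv
import io

# A few rows copied from bmi_and_life_expectancy.csv, held in a string
# so the sketch runs without the file on disk.
sample = '''Country,Life expectancy,BMI
Afghanistan,52.8,20.62058
"Congo, Dem. Rep.",57.5,19.86692
Japan,82.5,23.50004
'''

rows = list(csv.DictReader(io.StringIO(sample)))

# Quoted names that contain commas stay in one field
countries = [row["Country"] for row in rows]

# One row per country, one value per feature (here only BMI)
bmi_rows = [[float(row["BMI"])] for row in rows]
life_exp = [float(row["Life expectancy"]) for row in rows]

print(countries)
print(bmi_rows)
print(life_exp)
```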
You'll need to complete each of the following steps:
1. Load the data
- The data is in the file called "bmi_and_life_expectancy.csv".
- Use pandas read_csv to load the data into a dataframe (don't forget to import pandas!).
- Assign the dataframe to the variable bmi_life_data.
2. Build a linear regression model
- Create a regression model using scikit-learn's LinearRegression and assign it to bmi_life_model.
- Fit the model to the data.
3. Predict using the model
- Predict using a BMI of 21.07931 and assign it to the variable laos_life_exp.
Start Quiz:
# TODO: Add import statements
# Assign the dataframe to this variable.
# TODO: Load the data
bmi_life_data = None
# Make and fit the linear regression model
# TODO: Fit the model and assign it to bmi_life_model
bmi_life_model = None
# Make a prediction using the model
# TODO: Predict life expectancy for a BMI value of 21.07931
laos_life_exp = None
bmi_and_life_expectancy.csv:
Country,Life expectancy,BMI
Afghanistan,52.8,20.62058
Albania,76.8,26.44657
Algeria,75.5,24.59620
Andorra,84.6,27.63048
Angola,56.7,22.25083
Armenia,72.3,25.355420000000002
Australia,81.6,27.56373
Austria,80.4,26.467409999999997
Azerbaijan,69.2,25.65117
Bahamas,72.2,27.24594
Bangladesh,68.3,20.39742
Barbados,75.3,26.38439
Belarus,70.0,26.16443
Belgium,79.6,26.75915
Belize,70.7,27.02255
Benin,59.7,22.41835
Bhutan,70.7,22.82180
Bolivia,71.2,24.43335
Bosnia and Herzegovina,77.5,26.61163
Botswana,53.2,22.12984
Brazil,73.2,25.78623
Bulgaria,73.2,26.54286
Burkina Faso,58.0,21.27157
Burundi,59.1,21.50291
Cambodia,66.1,20.80496
Cameroon,56.6,23.68173
Canada,80.8,27.45210
Cape Verde,70.4,23.51522
Chad,54.3,21.48569
Chile,78.5,27.01542
China,73.4,22.92176
Colombia,76.2,24.94041
Comoros,67.1,22.06131
"Congo, Dem. Rep.",57.5,19.86692
"Congo, Rep.",58.8,21.87134
Costa Rica,79.8,26.47897
Cote d'Ivoire,55.4,22.56469
Croatia,76.2,26.59629
Cuba,77.6,25.06867
Cyprus,80.0,27.41899
Denmark,78.9,26.13287
Djibouti,61.8,23.38403
Ecuador,74.7,25.58841
Egypt,70.2,26.73243
El Salvador,73.7,26.36751
Eritrea,60.1,20.88509
Estonia,74.2,26.26446
Ethiopia,60.0,20.24700
Fiji,64.9,26.53078
Finland,79.6,26.73339
France,81.1,25.85329
French Polynesia,75.11,30.86752
Gabon,61.7,24.07620
Gambia,65.7,21.65029
Georgia,71.8,25.54942
Germany,80.0,27.16509
Ghana,62.0,22.84247
Greece,80.2,26.33786
Greenland,70.3,26.01359
Grenada,70.8,25.17988
Guatemala,71.2,25.29947
Guinea,57.1,22.52449
Guinea-Bissau,53.6,21.64338
Guyana,65.0,23.68465
Haiti,61.0,23.66302
Honduras,71.8,25.10872
Hungary,73.9,27.11568
Iceland,82.4,27.20687
India,64.7,20.95956
Indonesia,69.4,21.85576
Iran,73.1,25.31003
Iraq,66.6,26.71017
Ireland,80.1,27.65325
Israel,80.6,27.13151
Jamaica,75.1,24.00421
Japan,82.5,23.50004
Jordan,76.9,27.47362
Kazakhstan,67.1,26.29078
Kenya,60.8,21.59258
Kuwait,77.3,29.17211
Latvia,72.4,26.45693
Lesotho,44.5,21.90157
Liberia,59.9,21.89537
Libya,75.6,26.54164
Lithuania,72.1,26.86102
Luxembourg,81.0,27.43404
"Macedonia, FYR",74.5,26.34473
Madagascar,62.2,21.40347
Malawi,52.4,22.03468
Malaysia,74.5,24.73069
Maldives,78.5,23.21991
Mali,58.5,21.78881
Malta,80.7,27.68361
Marshall Islands,65.3,29.37337
Mauritania,67.9,22.62295
Mauritius,72.9,25.15669
Mexico,75.4,27.42468
Moldova,70.4,24.23690
Mongolia,64.8,24.88385
Montenegro,76.0,26.55412
Morocco,73.3,25.63182
Mozambique,54.0,21.93536
Myanmar,59.4,21.44932
Namibia,59.1,22.65008
Nepal,68.4,20.76344
Netherlands,80.3,26.01541
Nicaragua,77.0,25.77291
Niger,58.0,21.21958
Nigeria,59.2,23.03322
Norway,80.8,26.93424
Oman,76.2,26.24109
Pakistan,64.1,22.29914
Panama,77.3,26.26959
Papua New Guinea,58.6,25.01506
Paraguay,74.0,25.54223
Peru,78.2,24.77041
Philippines,69.8,22.87263
Poland,75.4,26.67380
Portugal,79.4,26.68445
Qatar,77.9,28.13138
Romania,73.2,25.41069
Russia,67.9,26.01131
Rwanda,64.1,22.55453
Samoa,72.3,30.42475
Sao Tome and Principe,66.0,23.51233
Senegal,63.5,21.92743
Serbia,74.3,26.51495
Sierra Leone,53.6,22.53139
Singapore,80.6,23.83996
Slovak Republic,74.9,26.92717
Slovenia,78.7,27.43983
Somalia,52.6,21.96917
South Africa,53.4,26.85538
Spain,81.1,27.49975
Sri Lanka,74.0,21.96671
Sudan,65.5,22.40484
Suriname,70.2,25.49887
Swaziland,45.1,23.16969
Sweden,81.1,26.37629
Switzerland,82.0,26.20195
Syria,76.1,26.91969
Tajikistan,69.6,23.77966
Tanzania,60.4,22.47792
Thailand,73.9,23.00803
Timor-Leste,69.9,20.59082
Togo,57.5,21.87875
Tonga,70.3,30.99563
Trinidad and Tobago,71.7,26.39669
Tunisia,76.8,25.15699
Turkey,77.8,26.70371
Turkmenistan,67.2,25.24796
Uganda,56.0,22.35833
Ukraine,67.8,25.42379
United Arab Emirates,75.6,28.05359
United Kingdom,79.7,27.39249
United States,78.3,28.45698
Uruguay,76.0,26.39123
Uzbekistan,69.6,25.32054
Vanuatu,63.4,26.78926
West Bank and Gaza,74.1,26.57750
Vietnam,74.1,20.91630
Zambia,51.1,20.68321
Zimbabwe,47.3,22.02660
Solution:
# TODO: Add import statements
import pandas as pd
from sklearn.linear_model import LinearRegression
# Assign the dataframe to this variable.
# TODO: Load the data
bmi_life_data = pd.read_csv("bmi_and_life_expectancy.csv")
# Make and fit the linear regression model
# TODO: Fit the model and assign it to bmi_life_model
bmi_life_model = LinearRegression()
bmi_life_model.fit(bmi_life_data[['BMI']], bmi_life_data[['Life expectancy']])
# Make a prediction using the model
# TODO: Predict life expectancy for a BMI value of 21.07931
# predict() expects a 2-D input: one row per sample, one column per feature
laos_life_exp = bmi_life_model.predict([[21.07931]])